"""
Evaluate the part tree's semantics
"""

import imageio
import torch
import numpy as np
import os, json
from tqdm import tqdm
import torch
import argparse
import sys
from PIL import Image

CODE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(CODE_DIR)

from openclip_utils import OpenCLIPNetwork, OpenCLIPNetworkConfig
from tree import PartTree, save_tree, load_tree

# Fail fast with an actionable message when open_clip is missing.
try:
    import open_clip
except ImportError as err:
    # Raise explicitly instead of `assert False`: assert statements are
    # stripped under `python -O`, which would silently hide the missing
    # dependency until a much later (and more confusing) failure.
    raise ImportError(
        "open_clip is not installed, install it with `pip install open-clip-torch`"
    ) from err


def label_tree_clip(model: OpenCLIPNetwork, tree: PartTree):
    """
    Embed every node of the tree with the CLIP image encoder (baseline method).

    The tree is asked to prepare a cropped, padded 224x224 query image per
    node; the images are batched, encoded with ``model.encode_image``,
    L2-normalised along the last dimension and stored on each node through
    ``node.set_embed`` as a half-precision CPU tensor.

    Args:
        model: CLIP wrapper exposing ``encode_image``.
        tree: Part tree whose nodes are embedded in place.

    Raises:
        AssertionError: If a node has no query image after preprocessing.
    """
    # Crop the image by bbox -> pad to 224x224
    tree.query_preprocess(crop=True, pad=True, resize=(224, 224))

    nodes = tree.get_nodes()
    tiles = []
    for node in nodes:
        img = node.query_image
        assert img is not None, f'Node {node} has no query image'
        tiles.append(img)

    # Stack to (N, H, W, C) and move channels first for the encoder.
    # assumes pixel values are in the 0-255 range — TODO confirm
    batch = torch.from_numpy(np.stack(tiles, axis=0).astype("float32"))
    batch = (batch.permute(0, 3, 1, 2) / 255.0).to('cuda')

    # Inference only: the result is detached right away, so skip autograd
    # bookkeeping instead of building a graph that is immediately discarded.
    with torch.no_grad():
        clip_embed = model.encode_image(batch)
        # Out-of-place division so the encoder's output tensor is not mutated.
        clip_embed = clip_embed / clip_embed.norm(dim=-1, keepdim=True)
    clip_embed = clip_embed.detach().cpu().half()

    for idx, node in enumerate(nodes):
        node.set_embed(clip_embed[idx])

def embed_captions(model: OpenCLIPNetwork, tree: PartTree):
    """
    Embed the captions already stored on the tree with the text encoder.

    For every node, the part of the caption before the first ``':'`` is
    encoded with ``model.encode_text``; the L2-normalised embedding is stored
    on the node through ``node.set_embed`` as a half-precision CPU tensor.

    Args:
        model: CLIP wrapper exposing ``encode_text``.
        tree: Part tree whose nodes already carry captions.

    Raises:
        AssertionError: If a node has an empty or missing caption.
    """
    nodes = tree.get_nodes()

    texts = []
    for node in nodes:
        caption = node.caption
        assert caption, f'Node {node} has no caption'
        # Only the text before the first ':' is embedded; captions written by
        # get_max_across look like "text: prob, text: prob, ...".
        texts.append(caption.split(':', 1)[0])

    # Inference only: the result is detached right away, so skip autograd
    # bookkeeping instead of building a graph that is immediately discarded.
    with torch.no_grad():
        clip_embed = model.encode_text(texts)
        # Out-of-place division so the encoder's output tensor is not mutated.
        clip_embed = clip_embed / clip_embed.norm(dim=-1, keepdim=True)
    clip_embed = clip_embed.detach().cpu().half()

    for idx, node in enumerate(nodes):
        node.set_embed(clip_embed[idx])
    
def get_max_across(model: OpenCLIPNetwork, tree: PartTree, target_text: list[str], k=5) -> torch.Tensor:
    """
    Score every tree node against every query phrase and caption the nodes.

    The phrases are registered on the model with ``set_positives``; the node
    embeddings are stacked and passed to ``model.get_max_across``. Each node's
    caption is then replaced by its ``k`` best-scoring phrases formatted as
    ``"<phrase>: <score>"`` and joined by ``", "``.

    Returns the relevancy matrix (phrases x nodes) as a CPU tensor.
    """
    model.set_positives(target_text)
    part_nodes = tree.get_nodes()

    vectors = []
    for part in part_nodes:
        vec = part.embed
        assert vec is not None, f'Node {part} has no semantic embedding'
        vectors.append(vec)
    vectors = torch.stack(vectors).to('cuda')  # one row per node

    relevancy = model.get_max_across(vectors).detach().cpu()  # phrases x nodes
    num_phrases, _ = relevancy.shape

    # Never ask for more phrases than there are.
    top_k = min(k, num_phrases)

    # Transpose so that the top-k is taken per node: both results are (nodes x k)
    top_scores, top_ids = torch.topk(relevancy.T, top_k)

    # Overwrite each node's caption with its best phrases and their scores
    for n, part in enumerate(part_nodes):
        phrase_ids = top_ids[n].tolist()
        scores = top_scores[n].tolist()
        pieces = [f'{target_text[j]}: {score:.2f}' for j, score in zip(phrase_ids, scores)]
        part.set_caption(', '.join(pieces))

    return relevancy

def plot_masks(phrase, mask_pred, mask_gt) -> None:
    """
    Show the predicted and ground-truth masks for ``phrase`` side by side.

    Debugging helper: opens a matplotlib window with two panels ('pred' and
    'GT'), each with its own colorbar, titled with the phrase.
    """
    import matplotlib.pyplot as plt

    # One row, two panels
    _fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = (('pred', mask_pred), ('GT', mask_gt))

    # Draw both masks first (axis ticks hidden) ...
    images = []
    for axis, (title, mask) in zip(axes, panels):
        images.append(axis.imshow(mask.astype(float), cmap='binary'))
        axis.set_title(title)
        axis.axis('off')

    # ... then attach a colorbar to each panel
    for axis, image in zip(axes, images):
        plt.colorbar(image, ax=axis, fraction=0.046, pad=0.04)

    # Title, layout, display
    plt.suptitle(phrase)
    plt.tight_layout()
    plt.show()


def calculate_iou(tree, target_text, relevancy, threshold=0.5) -> dict[str, float]:
    """
    Calculate the IoU between predicted and ground-truth masks for each phrase.

    For phrase ``i`` the predicted mask is the union of ``node.image`` over all
    nodes whose ``relevancy[i][node]`` exceeds ``threshold``; the ground-truth
    mask is the union of ``node.image`` over all nodes whose ``gt_caption``
    equals the phrase. Any non-zero element of ``node.image`` counts as set.

    Args:
        tree: Part tree providing ``get_nodes()`` and ``root.image``.
        target_text: Query phrases; row ``i`` of ``relevancy`` belongs to
            ``target_text[i]``.
        relevancy: Per-phrase, per-node scores (phrases x nodes).
        threshold: A node is activated when its score is strictly greater.

    Returns:
        Dictionary mapping each phrase to its IoU. When both masks are empty
        (union of zero elements) the IoU is defined as 1.0.
    """
    nodes = tree.get_nodes()
    ious = {}
    for i, phrase in enumerate(target_text):
        activation = relevancy[i] > threshold  # activation for each node

        # Boolean masks regardless of the image dtype. The previous in-place
        # `|=` on an image-dtype array raised TypeError for float images.
        mask_pred = np.zeros_like(tree.root.image, dtype=bool)
        mask_gt = np.zeros_like(tree.root.image, dtype=bool)
        for j, node in enumerate(nodes):
            if activation[j]:
                mask_pred = np.logical_or(mask_pred, node.image)
            if node.gt_caption == phrase:
                mask_gt = np.logical_or(mask_gt, node.image)

        # Visualize the IoU
        # plot_masks(phrase, mask_pred, mask_gt)

        intersection = np.count_nonzero(np.logical_and(mask_gt, mask_pred))
        union = np.count_nonzero(np.logical_or(mask_gt, mask_pred))

        # Nothing predicted and nothing expected: a perfect match by convention
        ious[phrase] = float(intersection / union) if union else 1.0

    return ious

def calculate_confusion_matrix(tree, target_text, relevancy, threshold=0.5) -> dict[str, np.ndarray]:
    """
    Build a node-level 2x2 confusion matrix for every phrase.

    A node is predicted positive for phrase ``i`` when ``relevancy[i][node]``
    is strictly greater than ``threshold`` and is actually positive when its
    ``gt_caption`` equals the phrase. Each matrix is laid out as
    ``[[tp, fp], [fn, tn]]``.

    Returns a dictionary mapping phrase -> confusion matrix.
    """
    part_nodes = tree.get_nodes()
    matrices = {}
    for row, phrase in enumerate(target_text):
        fired = relevancy[row] > threshold  # per-node activation flags

        # counts[r][c]: r = 0 when predicted positive, c = 0 when actually positive
        counts = [[0, 0], [0, 0]]
        for col, part in enumerate(part_nodes):
            predicted = bool(fired[col])
            actual = part.gt_caption == phrase
            counts[0 if predicted else 1][0 if actual else 1] += 1

        matrices[phrase] = np.array(counts)

    return matrices


def get_masks(tree, model, dataset, gt_path, clip=False, threshold=0.5):
    """
    Save, for every query phrase, the root image masked by the activated nodes.

    The tree is first embedded (CLIP image baseline when ``clip`` is set,
    caption text otherwise). Query phrases come from every ``'text'`` entry of
    ``<gt_path>/result.json`` when ``dataset`` is 'partnet' (case-insensitive),
    otherwise from a fixed list. One PNG per phrase is written to
    ``tree.output_dir`` (with a ``_CLIP`` suffix for the baseline).
    """
    embed_fn = label_tree_clip if clip else embed_captions
    embed_fn(model, tree)

    # Collect the input queries (ground truth)
    if dataset.lower() == 'partnet':
        with open(os.path.join(gt_path, 'result.json'), 'r') as fin:
            hierarchy = json.load(fin)[0]

        def walk(entry):
            # Pre-order traversal yielding the text of every annotated entry
            if 'text' in entry:
                yield entry['text']
            for child in entry.get('children', ()):
                yield from walk(child)

        phrases = list(set(walk(hierarchy)))
    else:
        phrases = ['dinner', 'hot dog', 'sausage', 'bun', 'mustard', 'plate']

    # Get activations. Rank each part.
    relevancy = get_max_across(model, tree, phrases)

    part_nodes = tree.get_nodes()
    suffix = '_CLIP' if clip else ''
    for row, phrase in enumerate(phrases):
        fired = relevancy[row] > threshold  # per-node activation flags

        # Union of the images of all activated nodes
        mask = np.zeros_like(tree.root.image)
        for col, part in enumerate(part_nodes):
            if fired[col]:
                mask = np.logical_or(mask, part.image)

        # Keep the root image inside the mask, zero elsewhere, and save it
        masked = np.where(mask, tree.root.image, 0)
        imageio.imsave(os.path.join(tree.output_dir, f"{phrase}{suffix}.png"), masked)


def evaluate_iou(tree, model, dataset, gt_path, clip=False, threshold=0.5):
    """
    Evaluate the tree against the query phrases and write the metrics to disk.

    The tree is first embedded (CLIP image baseline when ``clip`` is set,
    caption text otherwise). Query phrases come from every ``'text'`` entry of
    ``<gt_path>/result.json`` when ``dataset`` is 'partnet' (case-insensitive),
    otherwise from a fixed list. Two files are written to ``tree.output_dir``:
    the mean and per-phrase IoU, and the confusion matrix summed over phrases.
    """
    # Optional: get CLIP embed for each part
    embed_fn = label_tree_clip if clip else embed_captions
    embed_fn(model, tree)

    # Collect the input queries (ground truth)
    if dataset.lower() == 'partnet':
        with open(os.path.join(gt_path, 'result.json'), 'r') as fin:
            hierarchy = json.load(fin)[0]

        def walk(entry):
            # Pre-order traversal yielding the text of every annotated entry
            if 'text' in entry:
                yield entry['text']
            for child in entry.get('children', ()):
                yield from walk(child)

        phrases = list(set(walk(hierarchy)))
    else:
        phrases = ['sharp', 'cutting', 'dangerous part', 'metal', 'plastic', 'hold with hand']

    # Get activations. Rank each part.
    valid_map = get_max_across(model, tree, phrases)

    tag = '_CLIP' if clip else ''

    # IoU metrics: mean first, then one line per phrase in sorted order
    ious = calculate_iou(tree, phrases, valid_map, threshold)
    m_iou = np.nanmean(np.array(list(ious.values())))

    iou_path = os.path.join(tree.output_dir, f"iou{tag}_{threshold:.2f}.txt")
    with open(iou_path, 'w') as file:
        file.write(f"mIoU: {m_iou}\n")
        file.writelines(f"{k}: {v}\n" for k, v in sorted(ious.items()))

    # Confusion matrix: only the sum over all phrases is written
    cms = calculate_confusion_matrix(tree, phrases, valid_map, threshold)
    cm_total = np.stack(list(cms.values())).sum(axis=0)

    cm_path = os.path.join(tree.output_dir,f"confmat{tag}_{threshold:.2f}.txt")
    np.savetxt(cm_path, cm_total, fmt="%d")




def main(args):
    """
    Run the evaluation pipeline for one part tree.

    Loads the tree and the CLIP model, evaluates IoU when a ground-truth path
    is given, saves the per-phrase masks, renders the tree with its predicted
    captions and, when ground truth is available, renders it again with the
    ground-truth captions.
    """
    tree = load_tree(args.tree_path)
    model = OpenCLIPNetwork(OpenCLIPNetworkConfig)

    # Metrics need ground truth
    if args.gt_path:
        evaluate_iou(tree, model, args.dataset, args.gt_path, args.clip, args.threshold)

    # save the masks
    get_masks(tree, model, args.dataset, args.gt_path, args.clip, args.threshold)

    # Renders go next to the tree file
    render_dir = os.path.dirname(args.tree_path)
    activations_name = "tree_activations_CLIP" if args.clip else "tree_activations"
    tree.render_tree(os.path.join(render_dir, activations_name))

    if not args.gt_path:
        return

    # Render ground truth tree: swap every caption for its ground-truth one
    for node in tree.get_nodes():
        node.set_caption(node.gt_caption)
    tree.render_tree(os.path.join(render_dir, "tree_gt"))



# Command-line entry point: parse the options and run the evaluation.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # The tree path is mandatory: main() loads the tree from it and derives the
    # render directory from it, so a missing value can only fail later with an
    # obscure error. Let argparse report it up front instead.
    parser.add_argument("--tree_path", type=str, required=True, help="Path to part hierarchy tree")
    # Without a ground-truth path the IoU evaluation and GT render are skipped.
    parser.add_argument("--gt_path", type=str, help="Path to ground truth", default=None)
    # 'partnet' (case-insensitive) reads the queries from <gt_path>/result.json;
    # any other value uses the built-in phrase lists.
    parser.add_argument("--dataset", type=str, help="Dataset type", default="other")
    parser.add_argument("--clip", action="store_true", help="Eval via CLIP baseline")
    # A node is activated when its relevancy is strictly above this value.
    parser.add_argument("--threshold", type=float, default=0.5)

    args = parser.parse_args()
    main(args)

